"""
Phase 1: Data Construction & Historical Classification
Development Threshold Paper

Classifies all countries by GDP/cap PPP trajectory through the $9k-$25k zone.
Computes crossing dates, transit times, and entry/exit characteristics.
"""

import pandas as pd
import numpy as np
import sys
sys.path.insert(0, '/mnt/c/demographics_capital_flows/multilateral/src')

# ═══════════════════════════════════════════════════════════════════════
# Load data
# ═══════════════════════════════════════════════════════════════════════
full = pd.read_csv('/mnt/c/demographics_capital_flows/multilateral/140_country/data/processed/full_panel.csv')
# Keep observations up to and including 2024.
panel = full[full['year'] <= 2024].copy()

# Resource rents are an optional input. When the raw extract cannot be read
# or does not have the expected columns, fall back to an all-NaN column so the
# rest of the script can still reference `resource_rents_gdp`.
try:
    rents = pd.read_csv('/mnt/c/demographics_capital_flows/multilateral/140_country/data/raw/wdi_resource_rents.csv')
    panel = panel.merge(rents[['iso3', 'year', 'resource_rents_gdp']], on=['iso3', 'year'], how='left')
except (OSError, KeyError, ValueError) as exc:
    # OSError: file missing/unreadable; KeyError: expected columns absent;
    # ValueError: pandas parse/merge errors. Anything else (KeyboardInterrupt,
    # programming errors) is deliberately allowed to propagate.
    print(f"WARNING: resource rents not merged ({type(exc).__name__}: {exc}); "
          f"resource_rents_gdp set to NaN", file=sys.stderr)
    panel['resource_rents_gdp'] = np.nan

# GDP/cap PPP bounds of the development-threshold zone.
LOWER = 9000
UPPER = 25000

# ═══════════════════════════════════════════════════════════════════════
# 1a. Classify all countries by threshold trajectory
# ═══════════════════════════════════════════════════════════════════════
print("=" * 70)
print("1a. COUNTRY CLASSIFICATION BY THRESHOLD TRAJECTORY")
print("=" * 70)

# Per-country summary of the non-missing GDP/cap PPP observations.
# The 'first'/'last' aggregations take values in row order, so the rows are
# sorted by (iso3, year) first; otherwise first_gdp/last_gdp are not guaranteed
# to be the values observed in first_year/last_year.
# NOTE(review): an earlier comment here called for at least 10 years of GDP
# data, but no such filter is applied; n_obs is recorded so one can be added
# if that is the intent -- confirm.
countries_agg = (
    panel[panel['gdp_pc_ppp'].notna()]
    .sort_values(['iso3', 'year'])
    .groupby('iso3')
    .agg(
        first_year=('year', 'min'),
        last_year=('year', 'max'),
        first_gdp=('gdp_pc_ppp', 'first'),
        last_gdp=('gdp_pc_ppp', 'last'),
        min_gdp=('gdp_pc_ppp', 'min'),
        max_gdp=('gdp_pc_ppp', 'max'),
        n_obs=('gdp_pc_ppp', 'count'),
    )
)

def classify(row, lower=None, upper=None):
    """Classify a country's trajectory relative to the [lower, upper] zone.

    The label is derived from the first and last observed GDP/cap values
    (``row['first_gdp']``, ``row['last_gdp']``); ``row['max_gdp']`` is only
    consulted to tell 'Entered zone then fell back' from 'Still below'.

    Args:
        row: Mapping-like object (e.g. a DataFrame row) with the keys
            'first_gdp', 'last_gdp' and 'max_gdp'.
        lower: Lower bound of the zone (inclusive). Defaults to the
            module-level ``LOWER``.
        upper: Upper bound of the zone (inclusive). Defaults to the
            module-level ``UPPER``.

    Returns:
        One of the status strings: 'Crossed (below→above)',
        'In zone (entered from below)', 'Entered zone then fell back',
        'Still below', 'Crossed (zone→above)', 'Stuck in zone',
        'Fell back below', 'Always above' or 'Other'.

    NOTE(review): 'Always above' is returned whenever the *first* observation
    is above ``upper``, regardless of later values, so a country that started
    above the zone and later fell into it still gets this label.
    """
    lower = LOWER if lower is None else lower
    upper = UPPER if upper is None else upper

    first_gdp = row['first_gdp']
    last_gdp = row['last_gdp']

    started_below = first_gdp < lower
    started_in_zone = lower <= first_gdp <= upper
    ended_below = last_gdp < lower
    ended_in_zone = lower <= last_gdp <= upper
    ended_above = last_gdp > upper

    if started_below and ended_above:
        return 'Crossed (below→above)'
    if started_below and ended_in_zone:
        return 'In zone (entered from below)'
    if started_below and ended_below:
        # Start and end are both below the zone; the peak tells us whether
        # the country ever reached the lower bound in between.
        if row['max_gdp'] >= lower:
            return 'Entered zone then fell back'
        return 'Still below'
    if started_in_zone and ended_above:
        return 'Crossed (zone→above)'
    if started_in_zone and ended_in_zone:
        return 'Stuck in zone'
    if started_in_zone and ended_below:
        return 'Fell back below'
    if first_gdp > upper:
        return 'Always above'
    # Reached only when a comparison above is undefined (e.g. NaN values).
    return 'Other'

# Attach the trajectory label produced by classify() to every country.
countries_agg['status'] = countries_agg.apply(classify, axis=1)

# Report the status groups in a fixed order. Groups of up to 30 countries are
# listed individually, ordered by last observed GDP/cap (highest first).
print("\nClassification summary:")
status_counts = countries_agg['status'].value_counts()
REPORT_ORDER = (
    'Crossed (below→above)',
    'Crossed (zone→above)',
    'In zone (entered from below)',
    'Stuck in zone',
    'Still below',
    'Entered zone then fell back',
    'Fell back below',
    'Always above',
    'Other',
)
for status in REPORT_ORDER:
    n = status_counts.get(status, 0)
    if n <= 0:
        continue
    print(f"\n  {status}: {n} countries")
    if n > 30:
        # Large groups are summarised by their count only.
        continue
    members = countries_agg[countries_agg['status'] == status].sort_values('last_gdp', ascending=False)
    detail = members[['first_gdp', 'first_year', 'last_gdp', 'last_year']]
    for iso, first_gdp, first_year, last_gdp, last_year in detail.itertuples(name=None):
        print(f"    {iso}: ${first_gdp:,.0f} ({int(first_year)}) → "
              f"${last_gdp:,.0f} ({int(last_year)})")

# Persist the per-country classification (index = iso3).
countries_agg.to_csv('/mnt/c/demographics_capital_flows/development_threshold/data/country_classification.csv')

# ═══════════════════════════════════════════════════════════════════════
# 1b. Crossing dates and transit times
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("1b. CROSSING DATES AND TRANSIT TIMES")
print("=" * 70)

# Countries with at least one observation inside the [LOWER, UPPER] band.
ever_in_zone = panel[(panel['gdp_pc_ppp'] >= LOWER) & (panel['gdp_pc_ppp'] <= UPPER)]['iso3'].unique()
# Countries whose classification status starts with 'Crossed'.
crossed = countries_agg[countries_agg['status'].str.startswith('Crossed')].index.tolist()
crossed_lookup = set(crossed)  # constant-time membership tests inside the loop

# Variables snapshotted in the year a country first reaches UPPER (exit) and
# LOWER (entry). Defined once instead of being rebuilt for every country.
EXIT_SNAPSHOT_VARS = [
    'Z_1', 'Z_2', 'Z_3', 'old_dep', 'youth_dep', 'working_age_share',
    'kaopen', 'ca_gdp', 'gross_savings_gdp', 'fiscal_bal_gdp',
    'nfa_gdp', 'resource_rents_gdp', 'rgdp_growth', 'gdp_pc_ppp',
]
ENTRY_SNAPSHOT_VARS = EXIT_SNAPSHOT_VARS + [
    'trade_openness', 'life_expectancy', 'human_capital',
    'gross_investment_gdp', 'inflation',
]

# Optional columns: checked once, since every per-country slice shares the
# panel's columns.
has_growth = 'rgdp_growth' in panel.columns
has_z1 = 'Z_1' in panel.columns

crossing_records = []
# sort=False keeps countries in order of first appearance in the panel.
for iso, country_rows in panel.groupby('iso3', sort=False):
    c = country_rows.sort_values('year')
    gdp_s = c.loc[c['gdp_pc_ppp'].notna(), ['year', 'gdp_pc_ppp']]
    if len(gdp_s) < 5:
        # Too few GDP observations to date a crossing.
        continue

    # First year at or above LOWER
    above_lower = gdp_s[gdp_s['gdp_pc_ppp'] >= LOWER]
    cross_9k = above_lower['year'].min() if len(above_lower) > 0 else np.nan

    # First year at or above UPPER
    above_upper = gdp_s[gdp_s['gdp_pc_ppp'] >= UPPER]
    cross_25k = above_upper['year'].min() if len(above_upper) > 0 else np.nan

    # Time in zone. A completed transit gives years_in_zone; when UPPER was
    # never reached, the true transit time is unknown (NaN) and the time
    # observed so far is recorded separately as a censored duration.
    years_in_zone = np.nan
    years_in_zone_censored = np.nan
    if pd.notna(cross_9k) and pd.notna(cross_25k):
        years_in_zone = cross_25k - cross_9k
    elif pd.notna(cross_9k):
        years_in_zone_censored = gdp_s['year'].max() - cross_9k

    # Within-zone dynamics: from the first year at LOWER up to the first year
    # at UPPER, or to the end of the sample when UPPER was never reached.
    if pd.notna(cross_9k):
        if pd.notna(cross_25k):
            in_zone_data = c[(c['year'] >= cross_9k) & (c['year'] <= cross_25k)]
        else:
            in_zone_data = c[c['year'] >= cross_9k]

        zone_growth = in_zone_data['rgdp_growth'].mean() if has_growth else np.nan
        zone_growth_vol = in_zone_data['rgdp_growth'].std() if has_growth else np.nan

        # Max drawdown in GDP/cap: largest proportional fall from the running peak.
        gdp_in_zone = in_zone_data['gdp_pc_ppp'].dropna()
        if len(gdp_in_zone) > 1:
            running_peak = gdp_in_zone.cummax()
            drawdown = ((gdp_in_zone - running_peak) / running_peak).min()
        else:
            drawdown = 0
    else:
        zone_growth = np.nan
        zone_growth_vol = np.nan
        drawdown = np.nan

    # Entry characteristics: values in the year LOWER was first reached.
    entry_vars = {}
    if pd.notna(cross_9k):
        entry_row = c[c['year'] == cross_9k]
        if len(entry_row) > 0:
            entry_vars = {
                f'{var}_entry': entry_row[var].values[0]
                for var in ENTRY_SNAPSHOT_VARS
                if var in entry_row.columns
            }

    # Exit characteristics: values in the year UPPER was first reached.
    exit_vars = {}
    if pd.notna(cross_25k):
        exit_row = c[c['year'] == cross_25k]
        if len(exit_row) > 0:
            exit_vars = {
                f'{var}_exit': exit_row[var].values[0]
                for var in EXIT_SNAPSHOT_VARS
                if var in exit_row.columns
            }

    # Change in Z_1 between entry and exit (or the country's last panel year
    # when it never reached UPPER).
    if pd.notna(cross_9k):
        z1_entry = entry_vars.get('Z_1_entry', np.nan)
        if pd.notna(cross_25k):
            z1_exit = exit_vars.get('Z_1_exit', np.nan)
        else:
            last_row = c[c['year'] == c['year'].max()]
            # Same column guard as the entry/exit snapshots, so a panel
            # without Z_1 yields NaN instead of a KeyError.
            z1_exit = last_row['Z_1'].values[0] if has_z1 and len(last_row) > 0 else np.nan
        z1_change = z1_exit - z1_entry if pd.notna(z1_entry) and pd.notna(z1_exit) else np.nan
    else:
        z1_change = np.nan

    record = {
        'iso3': iso,
        'status': countries_agg.loc[iso, 'status'] if iso in countries_agg.index else 'Unknown',
        'cross_9k': cross_9k,
        'cross_25k': cross_25k,
        'years_in_zone': years_in_zone,
        'years_in_zone_censored': years_in_zone_censored,
        'zone_growth_avg': zone_growth,
        'zone_growth_vol': zone_growth_vol,
        'max_drawdown': drawdown,
        'Z_1_change_in_zone': z1_change,
        'exited_above': 1 if iso in crossed_lookup else 0,
    }
    record.update(entry_vars)
    record.update(exit_vars)
    crossing_records.append(record)

cross_df = pd.DataFrame(crossing_records)
cross_df.to_csv('/mnt/c/demographics_capital_flows/development_threshold/data/crossing_data.csv', index=False)

# Print crossers: one line per country flagged exited_above, shortest transit first.
def _fmt_cell(value, spec, suffix='', scale=None):
    """Format *value* with *spec* plus *suffix*, or return '—' when it is missing."""
    if pd.isna(value):
        return '—'
    if scale is not None:
        value = value * scale
    return format(value, spec) + suffix


def _fmt_year(value):
    """Render a crossing year as an integer string, or '?' when it is missing."""
    if pd.isna(value):
        return '?'
    return f"{int(value)}"


print("\nCountries that crossed $9k → $25k:")
crossers_df = cross_df[cross_df['exited_above'] == 1].sort_values('years_in_zone')
print(f"{'Country':6s} {'Enter $9k':>10s} {'Exit $25k':>10s} {'Years':>6s} "
      f"{'Z₁ entry':>9s} {'Z₁ exit':>8s} {'ΔZ₁':>6s} {'Growth':>7s} {'Drawdown':>9s}")
print("-" * 80)
for _, row in crossers_df.iterrows():
    # .get() tolerates columns that were never populated for any country.
    z1e = _fmt_cell(row.get('Z_1_entry'), '.2f')
    z1x = _fmt_cell(row.get('Z_1_exit'), '.2f')
    dz1 = _fmt_cell(row.get('Z_1_change_in_zone'), '.2f')
    gr = _fmt_cell(row.get('zone_growth_avg'), '.1f', suffix='%')
    dd = _fmt_cell(row.get('max_drawdown'), '.0f', suffix='%', scale=100)
    yrs = _fmt_cell(row['years_in_zone'], '.0f')
    c9 = _fmt_year(row['cross_9k'])
    c25 = _fmt_year(row['cross_25k'])
    print(f"{row['iso3']:6s} {c9:>10s} {c25:>10s} {yrs:>6s} {z1e:>9s} {z1x:>8s} {dz1:>6s} {gr:>7s} {dd:>9s}")

# ═══════════════════════════════════════════════════════════════════════
# 1c. Summary statistics by status group
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("1c. ENTRY CHARACTERISTICS BY STATUS GROUP")
print("=" * 70)

# Imported once here rather than on every pass through the comparison loop.
from scipy import stats

entry_vars_list = ['Z_1_entry', 'old_dep_entry', 'youth_dep_entry', 'working_age_share_entry',
                   'kaopen_entry', 'ca_gdp_entry', 'gross_savings_gdp_entry', 'fiscal_bal_gdp_entry',
                   'resource_rents_gdp_entry', 'rgdp_growth_entry', 'gdp_pc_ppp_entry',
                   'trade_openness_entry', 'life_expectancy_entry', 'human_capital_entry',
                   'gross_investment_gdp_entry']

# Only countries with a dated entry into the zone, excluding those classified
# as 'Always above'.
zone_entrants = cross_df[cross_df['cross_9k'].notna() & (cross_df['status'] != 'Always above')].copy()

print(f"\nCountries that entered the zone: {len(zone_entrants)}")
print(f"  Exited above: {zone_entrants['exited_above'].sum()}")
print(f"  Did not exit: {(zone_entrants['exited_above'] == 0).sum()}")

# The crosser / non-crosser split does not depend on the variable being
# compared, so it is computed once outside the loop.
crosser_rows = zone_entrants[zone_entrants['exited_above'] == 1]
non_crosser_rows = zone_entrants[zone_entrants['exited_above'] == 0]

for var in entry_vars_list:
    if var not in zone_entrants.columns:
        continue
    crosser_vals = crosser_rows[var].dropna()
    non_crosser_vals = non_crosser_rows[var].dropna()
    # Require at least three non-missing values per group before testing.
    if len(crosser_vals) > 2 and len(non_crosser_vals) > 2:
        # equal_var=False: do not assume equal group variances.
        _, p = stats.ttest_ind(crosser_vals, non_crosser_vals, equal_var=False)
        sig = '***' if p < 0.01 else '**' if p < 0.05 else '*' if p < 0.1 else ''
        print(f"  {var:30s}  Crossers: {crosser_vals.mean():8.3f}  Non-crossers: {non_crosser_vals.mean():8.3f}  "
              f"Diff: {crosser_vals.mean()-non_crosser_vals.mean():+8.3f}  p={p:.4f}{sig}")

# ═══════════════════════════════════════════════════════════════════════
# 1d. Within-zone dynamics
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("1d. WITHIN-ZONE DYNAMICS")
print("=" * 70)

# Boolean masks splitting zone entrants by whether they exited above the zone.
exited_mask = zone_entrants['exited_above'] == 1
stayed_mask = zone_entrants['exited_above'] == 0

# Non-missing within-zone metrics for each group.
c_growth = zone_entrants.loc[exited_mask, 'zone_growth_avg'].dropna()
nc_growth = zone_entrants.loc[stayed_mask, 'zone_growth_avg'].dropna()
c_vol = zone_entrants.loc[exited_mask, 'zone_growth_vol'].dropna()
nc_vol = zone_entrants.loc[stayed_mask, 'zone_growth_vol'].dropna()
c_dd = zone_entrants.loc[exited_mask, 'max_drawdown'].dropna()
nc_dd = zone_entrants.loc[stayed_mask, 'max_drawdown'].dropna()

print(f"\nGrowth in zone:  Crossers mean={c_growth.mean():.2f}%  Non-crossers mean={nc_growth.mean():.2f}%")
print(f"Growth vol:      Crossers mean={c_vol.mean():.2f}%  Non-crossers mean={nc_vol.mean():.2f}%")
print(f"Max drawdown:    Crossers mean={c_dd.mean()*100:.1f}%  Non-crossers mean={nc_dd.mean()*100:.1f}%")

# Change in Z₁ between zone entry and exit / end of sample, by group.
c_dz1 = zone_entrants.loc[exited_mask, 'Z_1_change_in_zone'].dropna()
nc_dz1 = zone_entrants.loc[stayed_mask, 'Z_1_change_in_zone'].dropna()
print(f"\nΔZ₁ during transit: Crossers mean={c_dz1.mean():.3f}  Non-crossers mean={nc_dz1.mean():.3f}")

# ═══════════════════════════════════════════════════════════════════════
# 1e. Output tables
# ═══════════════════════════════════════════════════════════════════════

# Table 1: number of countries and mean first/last GDP/cap per status group.
status_groups = countries_agg.groupby('status')
table1 = status_groups.agg(
    n_countries=('n_obs', 'count'),
    avg_first_gdp=('first_gdp', 'mean'),
    avg_last_gdp=('last_gdp', 'mean'),
)
table1 = table1.round(0)
table1.to_csv('/mnt/c/demographics_capital_flows/development_threshold/output/tables/table1_classification.csv')

# Table 2: per-country crossing details for countries flagged exited_above,
# ordered by transit time.
TABLE2_COLUMNS = [
    'iso3', 'cross_9k', 'cross_25k', 'years_in_zone', 'Z_1_entry', 'Z_1_exit',
    'Z_1_change_in_zone', 'old_dep_entry', 'kaopen_entry', 'resource_rents_gdp_entry',
    'zone_growth_avg', 'max_drawdown',
]
crossers_out = cross_df.loc[cross_df['exited_above'] == 1, TABLE2_COLUMNS]
crossers_out = crossers_out.sort_values('years_in_zone')
crossers_out.to_csv('/mnt/c/demographics_capital_flows/development_threshold/output/tables/table2_crossers.csv', index=False)

# Table 3: entry characteristics of crossers vs non-crossers, with a t-test
# that does not assume equal variances (equal_var=False).
from scipy import stats
exited_rows = zone_entrants[zone_entrants['exited_above'] == 1]
stayed_rows = zone_entrants[zone_entrants['exited_above'] == 0]
comparison_rows = []
for var in entry_vars_list:
    if var not in zone_entrants.columns:
        continue
    c_vals = exited_rows[var].dropna()
    nc_vals = stayed_rows[var].dropna()
    # Skip variables with fewer than three non-missing values in either group.
    if len(c_vals) <= 2 or len(nc_vals) <= 2:
        continue
    _t_stat, p_value = stats.ttest_ind(c_vals, nc_vals, equal_var=False)
    crossers_mean = c_vals.mean()
    non_crossers_mean = nc_vals.mean()
    comparison_rows.append({
        'variable': var.replace('_entry', ''),
        'crossers_mean': crossers_mean,
        'crossers_sd': c_vals.std(),
        'non_crossers_mean': non_crossers_mean,
        'non_crossers_sd': nc_vals.std(),
        'difference': crossers_mean - non_crossers_mean,
        'p_value': p_value,
        'n_crossers': len(c_vals),
        'n_non_crossers': len(nc_vals),
    })

table3 = pd.DataFrame(comparison_rows)
table3.to_csv('/mnt/c/demographics_capital_flows/development_threshold/output/tables/table3_entry_comparison.csv', index=False)

print("\n✓ Phase 1 complete. Files saved to development_threshold/data/ and output/tables/")
